{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"align-center\">\n",
    "<a href=\"https://oumi.ai/\"><img src=\"https://oumi.ai/docs/en/latest/_static/logo/header_logo.png\" height=\"200\"></a>\n",
    "\n",
    "[![Documentation](https://img.shields.io/badge/Documentation-latest-blue.svg)](https://oumi.ai/docs/en/latest/index.html)\n",
    "[![Discord](https://img.shields.io/discord/1286348126797430814?label=Discord)](https://discord.gg/oumi)\n",
    "[![GitHub Repo stars](https://img.shields.io/github/stars/oumi-ai/oumi)](https://github.com/oumi-ai/oumi)\n",
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/oumi-ai/oumi/blob/main/configs/projects/halloumi/halloumi_inference_notebook.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
    "</div>\n",
    "\n",
    "👋 Welcome to Open Universal Machine Intelligence (Oumi)!\n",
    "\n",
    "🚀 Oumi is a fully open-source platform that streamlines the entire lifecycle of foundation models - from [data preparation](https://oumi.ai/docs/en/latest/resources/datasets/datasets.html) and [training](https://oumi.ai/docs/en/latest/user_guides/train/train.html) to [evaluation](https://oumi.ai/docs/en/latest/user_guides/evaluate/evaluate.html) and [deployment](https://oumi.ai/docs/en/latest/user_guides/launch/launch.html). Whether you're developing on a laptop, launching large scale experiments on a cluster, or deploying models in production, Oumi provides the tools and workflows you need.\n",
    "\n",
    "🤝 Make sure to join our [Discord community](https://discord.gg/oumi) to get help, share your experiences, and contribute to the project! If you are interested in joining one of the community's open-science efforts, check out our [open collaboration](https://oumi.ai/community) page.\n",
    "\n",
    "⭐ If you like Oumi and you would like to support it, please give it a star on [GitHub](https://github.com/oumi-ai/oumi)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# HallOumi Inference\n",
    "\n",
    "This notebook demonstrates how you can run inference locally for HallOumi 8B. For more details on HallOumi, please read our [GitHub documentation](https://github.com/oumi-ai/oumi/blob/main/configs/projects/halloumi/README.md) and our [blog post](https://oumi.ai/blog/posts/introducing-halloumi)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install Oumi, so that you can use our inference engines. You can find more detailed instructions about Oumi installation [here](https://oumi.ai/docs/en/latest/get_started/installation.html). If you're running this notebook on a CUDA-compatible GPU and want to use vLLM for inference, make sure to install the optional Oumi `[gpu]` dependencies.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install oumi\n",
    "# %pip install oumi[gpu]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install the `nltk` library and download `punkt_tab`. These are needed for sentence splitting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install nltk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nltk\n",
    "\n",
    "nltk.download(\"punkt_tab\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helper Functions\n",
    "\n",
    "The following function is used to create the prompt for HallOumi from a `context` document, a `request` to a language model, and its corresponding `response`. HallOumi's objective is to determine whether the language model hallucinated, meaning that the `response` cannot be grounded to the provided `context`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nltk.tokenize import sent_tokenize\n",
    "\n",
    "\n",
    "def create_prompt(context: str, request: str, response: str) -> str:\n",
    "    \"\"\"Generates a prompt for the generative HallOumi model.\"\"\"\n",
    "\n",
    "    def _split_into_sentences(text: str) -> list[str]:\n",
    "        sentences = sent_tokenize(text.strip())\n",
    "        return [s.strip() for s in sentences if s.strip()]\n",
    "\n",
    "    def _annotate_sentences(sentences: list[str], annotation_char: str) -> str:\n",
    "        annotated_sentences = []\n",
    "        for idx, sentence in enumerate(sentences, start=1):\n",
    "            annotated_sentences.append(\n",
    "                f\"<|{annotation_char}{idx}|><{sentence}><end||{annotation_char}>\"\n",
    "            )\n",
    "        return \"\".join(annotated_sentences)\n",
    "\n",
    "    # Context: Split it into sentences and annotate them.\n",
    "    context_sentences = _split_into_sentences(context)\n",
    "    annotated_context_sentences = _annotate_sentences(context_sentences, \"s\")\n",
    "    annotated_context = f\"<|context|>{annotated_context_sentences}<end||context>\"\n",
    "\n",
    "    # Request: Annotate the request.\n",
    "    annotated_request = f\"<|request|><{request.strip()}><end||request>\"\n",
    "\n",
    "    # Response: Split it into sentences and annotate them.\n",
    "    response_sentences = _split_into_sentences(response)\n",
    "    annotated_response_sentences = _annotate_sentences(response_sentences, \"r\")\n",
    "    annotated_response = f\"<|response|>{annotated_response_sentences}<end||response>\"\n",
    "\n",
    "    # Combine all parts into the final prompt.\n",
    "    return f\"{annotated_context}{annotated_request}{annotated_response}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following function is used to extract a list of `Claim`s from HallOumi's response. The `Claim` class encapsulates the prediction (`supported`), the `rationale`, a list of `subclaims` that the claim consists of, and their corresponding `citations`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import contextlib\n",
    "from dataclasses import dataclass, field\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class Claim:\n",
    "    claim_id: int = -1\n",
    "    claim_string: str = \"\"\n",
    "    subclaims: list[str] = field(default_factory=list)\n",
    "    citations: list[int] = field(default_factory=list)\n",
    "    rationale: str = \"\"\n",
    "    supported: bool = True\n",
    "\n",
    "\n",
    "def get_claims_from_response(response: str) -> list[Claim]:\n",
    "    \"\"\"Extracts claims from the response string.\"\"\"\n",
    "\n",
    "    def _get_claim_id_from_subsegment(subsegment: str) -> int:\n",
    "        claim_id_part = subsegment.split(\"|\")[1]\n",
    "        claim_id_no_r = claim_id_part.lstrip(\"r\")\n",
    "        return int(claim_id_no_r)\n",
    "\n",
    "    def _get_claim_citations_from_subsegment(subsegment: str) -> list[int]:\n",
    "        citation_segments = subsegment.split(\",\")\n",
    "        citations = []\n",
    "        for citation_segment in citation_segments:\n",
    "            citation = citation_segment.replace(\"|\", \"\").replace(\"s\", \"\").strip()\n",
    "            if \"-\" in citation:\n",
    "                start, end = map(int, citation.split(\"-\"))\n",
    "                citations.extend(range(start, end + 1))\n",
    "            elif \"to\" in citation:\n",
    "                start, end = map(int, citation.split(\"to\"))\n",
    "                citations.extend(range(start, end + 1))\n",
    "            else:\n",
    "                with contextlib.suppress(ValueError):\n",
    "                    citation_int = int(citation)\n",
    "                    citations.append(citation_int)\n",
    "        return citations\n",
    "\n",
    "    def _get_claim_from_segment(segment: str) -> Claim:\n",
    "        claim_segments = segment.split(\"><\")\n",
    "        claim = Claim()\n",
    "        claim.claim_id = _get_claim_id_from_subsegment(claim_segments[0])\n",
    "        claim.claim_string = claim_segments[1]\n",
    "\n",
    "        subclaims = []\n",
    "        claim_progress_index = 3  # start parsing subclaims from index 3\n",
    "        for i in range(claim_progress_index, len(claim_segments)):\n",
    "            subsegment = claim_segments[i]\n",
    "            if subsegment.startswith(\"end||subclaims\"):\n",
    "                claim_progress_index = i + 1\n",
    "                break\n",
    "            subclaims.append(subsegment)\n",
    "\n",
    "        citation_index = -1\n",
    "        rationale_index = -1\n",
    "        label_index = -1\n",
    "\n",
    "        for i in range(claim_progress_index, len(claim_segments)):\n",
    "            subsegment = claim_segments[i]\n",
    "            if subsegment.startswith(\"|cite|\"):\n",
    "                citation_index = i + 1\n",
    "            elif subsegment.startswith(\"|explain|\"):\n",
    "                rationale_index = i + 1\n",
    "            elif subsegment.startswith(\"|supported|\") or subsegment.startswith(\n",
    "                \"|unsupported|\"\n",
    "            ):\n",
    "                label_index = i\n",
    "\n",
    "        claim.subclaims = subclaims\n",
    "        claim.citations = (\n",
    "            _get_claim_citations_from_subsegment(claim_segments[citation_index])\n",
    "            if citation_index != -1\n",
    "            else []\n",
    "        )\n",
    "        claim.rationale = (\n",
    "            claim_segments[rationale_index] if rationale_index != -1 else \"\"\n",
    "        )\n",
    "        claim.supported = (\n",
    "            claim_segments[label_index].startswith(\"|supported|\")\n",
    "            if label_index != -1\n",
    "            else True\n",
    "        )\n",
    "        return claim\n",
    "\n",
    "    segments = response.split(\"<end||r>\")\n",
    "    claims = []\n",
    "    for segment in segments:\n",
    "        if segment.strip():\n",
    "            claim = _get_claim_from_segment(segment)\n",
    "            claims.append(claim)\n",
    "    return claims"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference Walkthrough"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset\n",
    "\n",
    "Let's start by defining a toy dataset, where each example consists of a `context` document, a `request` to the language model, and the model's `reponse`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "toy_dataset = [\n",
    "    {\n",
    "        \"context\": \"Today is a sunny day. The weather is nice.\",\n",
    "        \"request\": \"What is the weather like today?\",\n",
    "        \"response\": \"The weather is sunny.\",\n",
    "    },\n",
    "    {\n",
    "        \"context\": \"James is a software engineer. He works at a tech company.\",\n",
    "        \"request\": \"What does James do for a living?\",\n",
    "        \"response\": \"He is a hardware engineer. He loves his job.\",\n",
    "    },\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We then convert these examples to a list of prompts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompts = []\n",
    "for example in toy_dataset:\n",
    "    prompt = create_prompt(\n",
    "        context=example[\"context\"],\n",
    "        request=example[\"request\"],\n",
    "        response=example[\"response\"],\n",
    "    )\n",
    "    prompts.append(prompt)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Running Inference\n",
    "\n",
    "Next step is to instantiate an inference config. You can use the remote config (`remote_config_str`) below to call our API. If you want to download the model and run inference locally in your machine, use the local config (`local_config_str`) instead. If you do NOT have a GPU in your local machine, you will need to use the `NATIVE` engine and the inference will be really slow. However, if you have a CUDA-compatible GPU, set the engine to `VLLM` instead, to take full advantage of the GPU and speed up your inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from oumi.core.configs import InferenceConfig\n",
    "\n",
    "remote_config_str = \"\"\"\n",
    "model:\n",
    "    model_name: \"halloumi\"\n",
    "\n",
    "generation:\n",
    "    max_new_tokens: 8192\n",
    "    temperature: 0.0\n",
    "\n",
    "remote_params:\n",
    "    api_url: \"https://api.oumi.ai/chat/completions\"\n",
    "    max_retries: 3\n",
    "    connection_timeout: 300\n",
    "\n",
    "engine: REMOTE\n",
    "\"\"\"\n",
    "\n",
    "local_config_str = \"\"\"\n",
    "model:\n",
    "    model_name: \"oumi-ai/HallOumi-8B\"\n",
    "    model_max_length: 8192\n",
    "    trust_remote_code: true\n",
    "\n",
    "generation:\n",
    "    max_new_tokens: 8192\n",
    "    temperature: 0.0\n",
    "\n",
    "engine: NATIVE  # Set to VLLM, if you have a CUDA-compatible GPU.\n",
    "\"\"\"\n",
    "\n",
    "inference_config = InferenceConfig.from_str(local_config_str)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using this inference config, run inference on the prompts, as shown below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO 04-03 21:30:36 [__init__.py:256] Automatically detected platform cpu.\n",
      "[2025-04-03 21:30:37,123][oumi][rank0][pid:3776][MainThread][INFO]][models.py:208] Building model using device_map: auto (DeviceRankInfo(world_size=1, rank=0, local_world_size=1, local_rank=0))...\n",
      "[2025-04-03 21:30:37,633][oumi][rank0][pid:3776][MainThread][INFO]][models.py:276] Using model class: <class 'transformers.models.auto.modeling_auto.AutoModelForCausalLM'> to instantiate model.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "20017e782b034770bd712bafc65e693b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:accelerate.big_modeling:Some parameters are on the meta device because they were offloaded to the disk.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2025-04-03 21:30:48,103][oumi][rank0][pid:3776][MainThread][INFO]][models.py:482] Using the model's built-in chat template for model 'oumi-ai/HallOumi-8B'.\n",
      "[2025-04-03 21:30:48,119][oumi][rank0][pid:3776][MainThread][INFO]][native_text_inference_engine.py:140] Setting EOS token id to `128009`\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating Model Responses: 100%|██████████| 2/2 [33:43<00:00, 1011.80s/it]\n"
     ]
    }
   ],
   "source": [
    "from oumi import infer\n",
    "\n",
    "inference_results = infer(\n",
    "    config=inference_config,\n",
    "    inputs=prompts,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inspecting the results\n",
    "\n",
    "Once inference completes, the last step is to iterate on the inference results: get the responses and extract the predictions, sub-claims, citations, and rationales."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[example=0, claim=0]: `The weather is sunny.`\n",
      "  - Supported? True\n",
      "  - Sub-claims: ['The current weather conditions are being described.', 'The description of the weather is that it is sunny.']\n",
      "  - Citations: [1]\n",
      "  - Rationale: The first sentence explicitly states that \"Today is a sunny day\", which directly supports the claim that the weather is sunny.\n",
      "\n",
      "[example=1, claim=0]: `He is a hardware engineer. `\n",
      "  - Supported? False\n",
      "  - Sub-claims: ['James works in the field of engineering.', 'James specifically works with hardware.']\n",
      "  - Citations: [1]\n",
      "  - Rationale: The document actually states that James is a software engineer, not a hardware engineer.\n",
      "\n",
      "[example=1, claim=1]: `He loves his job.`\n",
      "  - Supported? False\n",
      "  - Sub-claims: ['James has a positive sentiment towards his job.', 'James enjoys his work.']\n",
      "  - Citations: []\n",
      "  - Rationale: There is no information provided in the context about James's feelings towards his job.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for result_index, result in enumerate(inference_results):\n",
    "    # The model's response is the last message of the result (a `Conversation` object).\n",
    "    response = str(result.last_message().content)\n",
    "\n",
    "    claims = get_claims_from_response(response)\n",
    "    for claim_index, claim in enumerate(claims):\n",
    "        print(f\"[example={result_index}, claim={claim_index}]: `{claim.claim_string}`\")\n",
    "        print(f\"  - Supported? {claim.supported}\")\n",
    "        print(f\"  - Sub-claims: {claim.subclaims}\")\n",
    "        print(f\"  - Citations: {claim.citations}\")\n",
    "        print(f\"  - Rationale: {claim.rationale}\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "oumi",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
